{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AgiBot World Diffusion Policy Training Demo\n",
    "\n",
    "This notebook demonstrates how to use **AgiBotWorldDataset** to run an offline training workflow.\n",
    "Make sure you have installed all necessary packages before running.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================\n",
    "# 1. Imports and Parameter Settings\n",
    "# =============================================\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset\n",
    "from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig\n",
    "from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy\n",
    "\n",
    "# Parameters\n",
    "FPS = 30\n",
    "TASK_ID = 352\n",
    "training_steps = 5000\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Paths\n",
    "dataset_path = \"/path/to/your/AgiBotWorld/dataset\"\n",
    "output_path = \"/path/to/save/your/checkpoint\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================\n",
    "# 2. Dataset Setup\n",
    "# =============================================\n",
    "observation_idx = np.array([-1, 0])\n",
    "action_idx = np.arange(-1, 15)\n",
    "repo_id = f\"agibotworld/task_{TASK_ID}\"\n",
    "\n",
    "delta_timestamps = {\n",
    "    \"observation.images.top_head\": (observation_idx / FPS).tolist(),\n",
    "    \"observation.state\": (observation_idx / FPS).tolist(),\n",
    "    \"action\": (action_idx / FPS).tolist(),\n",
    "}\n",
    "\n",
    "dataset = LeRobotDataset(\n",
    "    repo_id=repo_id,\n",
    "    root=f\"{dataset_path}/{repo_id}\",\n",
    "    delta_timestamps=delta_timestamps,\n",
    "    local_files_only=True\n",
    ")\n",
    "\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    num_workers=0,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    pin_memory=(device.type == \"cuda\"),\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to train one robot policy model to master multiple distinct skills, you can use ’MultiLeRobotDataset‘ to load datasets for various tasks into a unified training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset\n",
    "repo_ids = [f\"agibotworld/{path.name}\" for path in Path(dataset_path).glob(\"agibotworld/task_*\")]\n",
    "multi_dataset = MultiLeRobotDataset(\n",
    "    repo_ids=repo_ids,\n",
    "    root=dataset_path,\n",
    "    delta_timestamps=delta_timestamps,\n",
    "    local_files_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's kick off a simple training with Diffusion Policy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================\n",
    "# 3. Policy Configuration and Initialization\n",
    "# =============================================\n",
    "cfg = DiffusionConfig()\n",
    "cfg.input_shapes = {\n",
    "    \"observation.images.top_head\": [3, 480, 640],\n",
    "    \"observation.state\": [20],\n",
    "}\n",
    "cfg.input_normalization_modes = {\n",
    "    \"observation.images.top_head\": \"mean_std\",\n",
    "    \"observation.state\": \"min_max\",\n",
    "}\n",
    "cfg.output_shapes = {\n",
    "    \"action\": [22],\n",
    "}\n",
    "\n",
    "policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)\n",
    "#policy = DiffusionPolicy(cfg, dataset_stats=multi_dataset.stats)\n",
    "policy.train()\n",
    "policy.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================\n",
    "# 4. Training Loop\n",
    "# =============================================\n",
    "step = 0\n",
    "done = False\n",
    "\n",
    "while not done:\n",
    "    for batch in dataloader:\n",
    "        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}\n",
    "        output_dict = policy.forward(batch)\n",
    "        loss = output_dict[\"loss\"]\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        print(f\"Step {step}, Loss: {loss.item():.3f}\")\n",
    "        step += 1\n",
    "        \n",
    "        if step >= training_steps:\n",
    "            done = True\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================\n",
    "# 5. Save Policy Checkpoint\n",
    "# =============================================\n",
    "policy.save_pretrained(output_path)\n",
    "print(f\"Model saved to {output_path}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congrats! Now please feel free to explore the AgiBot World!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
